{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Flash Evaluation on Streamspot Dataset:\n",
    "\n",
    "This notebook is dedicated to evaluating Flash on the Streamspot dataset, which are graph-level in nature. We employ Flash in graph-level detection mode to analyze this dataset effectively. Upon completion of the notebook execution, the results will be presented.\n",
    "\n",
    "## Dataset Access: \n",
    "- The Streamspot dataset can be accessed at the following link: [Streamspot Dataset](https://github.com/sbustreamspot/sbustreamspot-data).\n",
    "- Please download the required data files from the provided link.\n",
    "\n",
    "## Data Parsing and Execution:\n",
    "- Utilize the parser included in this notebook to process the downloaded files.\n",
    "- To obtain the evaluation results, execute all cells within this notebook.\n",
    "\n",
    "## Model Training and Execution Flexibility:\n",
    "- By default, the notebook operates using pre-trained model weights.\n",
    "- Additionally, this notebook offers the flexibility to set parameters for training Graph Neural Networks (GNNs) and word2vec models from scratch.\n",
    "- You can then utilize these freshly trained models to conduct the evaluation. \n",
    "\n",
    "Follow these guidelines for a thorough and efficient analysis of the Streamspot dataset using Flash."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "F1op-CbyLuN4",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "import torch\n",
    "from torch_geometric.data import Data\n",
    "import os\n",
    "import torch.nn.functional as F\n",
    "import json\n",
    "import warnings\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "warnings.filterwarnings('ignore')\n",
    "from torch_geometric.loader import NeighborLoader\n",
    "import multiprocessing\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "Train_Gnn = False\n",
    "Train_Word2vec = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nM7KaeCbA_mQ",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pprint import pprint\n",
    "import gzip\n",
    "from sklearn.manifold import TSNE\n",
    "import json\n",
    "import copy\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import csv\n",
    "def show(str):\n",
    "\tprint (str + ' ' + time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime(time.time())))\n",
    "\n",
    "def parse_data():\n",
    "    os.system('tar -zxvf all.tar.gz')\n",
    "\n",
    "    show('Start processing.')\n",
    "    data = []\n",
    "    gId = -1\n",
    "    with open('all.tsv') as f:\n",
    "        tsvreader = csv.reader(f, delimiter='\\t')\n",
    "        for row in tsvreader:\n",
    "            if int(row[5]) != gId:\n",
    "                gId = int(row[5])\n",
    "                show('Graph ' + str(gId))\n",
    "                scene = int(gId/100)+1\n",
    "                if not osp.exists('streamspot/'+str(scene)):\n",
    "                    os.system('mkdir streamspot/'+str(scene))\n",
    "                ff = open('streamspot/'+str(scene)+'/'+str(gId)+'.txt', 'w')\n",
    "            ff.write(str(row[0])+'\\t'+str(row[1])+'\\t'+str(row[2])+'\\t'+str(row[3])+'\\t'+str(row[4])+'\\t'+str(row[5])+'\\n')\n",
    "    os.system('rm all.tsv')\n",
    "    show('Done.')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def prepare_graph(df):\n",
    "    nodes, labels, edges = {}, {}, []\n",
    "    dummies = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7}\n",
    "\n",
    "    for _, row in df.iterrows():\n",
    "        actor_id, object_id = row[\"actorID\"], row[\"objectID\"]\n",
    "        action = row[\"action\"]\n",
    "\n",
    "        for entity_id in [actor_id, object_id]:\n",
    "            nodes.setdefault(entity_id, []).append(action)\n",
    "            if entity_id == actor_id:\n",
    "                labels[entity_id] = dummies[row['actor_type']]\n",
    "            else:\n",
    "                labels[entity_id] = dummies[row['object']]\n",
    "\n",
    "        edges.append((actor_id, object_id))\n",
    "\n",
    "    features, feat_labels, edge_index, mapping = [], [], [[], []], []\n",
    "    index_map = {}\n",
    "\n",
    "    for key, value in nodes.items():\n",
    "        index_map[key] = len(features)\n",
    "        features.append(value)\n",
    "        feat_labels.append(labels[key])\n",
    "        mapping.append(key)\n",
    "\n",
    "    for source, target in edges:\n",
    "        edge_index[0].append(index_map[source])\n",
    "        edge_index[1].append(index_map[target])\n",
    "\n",
    "    return features, feat_labels, edge_index, mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fmXWs1dKIzD8",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn import SAGEConv, GATConv\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class GCN(torch.nn.Module):\n",
    "    def __init__(self,in_channel,out_channel):\n",
    "        super().__init__()\n",
    "        self.conv1 = SAGEConv(in_channel, 32, normalize=True)\n",
    "        self.conv2 = SAGEConv(32, out_channel, normalize=True)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = x.relu()\n",
    "        x = F.dropout(x, p=0.5, training=self.training)\n",
    "\n",
    "        x = self.conv2(x, edge_index)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YBuP_tSq94f4",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def visualize(h, color):\n",
    "    z = TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())\n",
    "\n",
    "    plt.figure(figsize=(10,10))\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "\n",
    "    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap=\"Set2\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3PCP6SXwZaif",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from gensim.models.callbacks import CallbackAny2Vec\n",
    "import gensim\n",
    "from gensim.models import Word2Vec\n",
    "from multiprocessing import Pool\n",
    "from itertools import compress\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "class EpochSaver(CallbackAny2Vec):\n",
    "    '''Callback to save model after each epoch.'''\n",
    "\n",
    "    def __init__(self):\n",
    "        self.epoch = 0\n",
    "\n",
    "    def on_epoch_end(self, model):\n",
    "        model.save('trained_weights/streamspot/streamspot.model')\n",
    "        self.epoch += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P8oBL8LFaeOf",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class EpochLogger(CallbackAny2Vec):\n",
    "    '''Callback to log information about training'''\n",
    "\n",
    "    def __init__(self):\n",
    "        self.epoch = 0\n",
    "\n",
    "    def on_epoch_begin(self, model):\n",
    "        print(\"Epoch #{} start\".format(self.epoch))\n",
    "\n",
    "    def on_epoch_end(self, model):\n",
    "        print(\"Epoch #{} end\".format(self.epoch))\n",
    "        self.epoch += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Se7Ei4tAapVj",
    "tags": []
   },
   "outputs": [],
   "source": [
    "logger = EpochLogger()\n",
    "saver = EpochSaver()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "if Train_Word2vec:\n",
    "    phrases = []\n",
    "    for i in range(50):\n",
    "        print(i)\n",
    "        f = open(f\"streamspot/{i}.txt\")\n",
    "        data = f.read().split('\\n')\n",
    "\n",
    "        data = [line.split('\\t') for line in data]\n",
    "        df = pd.DataFrame (data, columns = ['actorID', 'actor_type','objectID','object','action','timestamp'])\n",
    "        df = df.dropna()\n",
    "        docs,labels,edges,mapp = prepare_graph(df)\n",
    "        phrases = phrases + docs\n",
    "        \n",
    "    word2vec = Word2Vec(sentences=phrases, vector_size=30, window=10, min_count=1, workers=8,epochs=100,callbacks=[saver,logger])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p3TAi69zI1bO",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.utils import class_weight\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import CrossEntropyLoss\n",
    "\n",
    "model = GCN(30,8).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Vn_pMyt5Jd-6",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import numpy as np\n",
    "from gensim.models import Word2Vec\n",
    "\n",
    "class PositionalEncoder:\n",
    "\n",
    "    def __init__(self, d_model, max_len=100000):\n",
    "        position = torch.arange(max_len).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
    "        self.pe = torch.zeros(max_len, d_model)\n",
    "        self.pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        self.pe[:, 1::2] = torch.cos(position * div_term)\n",
    "\n",
    "    def embed(self, x):\n",
    "        return x + self.pe[:x.size(0)]\n",
    "\n",
    "def infer(document):\n",
    "    word_embeddings = [w2vmodel.wv[word] for word in document if word in  w2vmodel.wv]\n",
    "    \n",
    "    if not word_embeddings:\n",
    "        return np.zeros(20)\n",
    "\n",
    "    output_embedding = torch.tensor(word_embeddings, dtype=torch.float)\n",
    "    if len(document) < 100000:\n",
    "        output_embedding = encoder.embed(output_embedding)\n",
    "\n",
    "    output_embedding = output_embedding.detach().cpu().numpy()\n",
    "    return np.mean(output_embedding, axis=0)\n",
    "\n",
    "encoder = PositionalEncoder(30)\n",
    "w2vmodel = Word2Vec.load(\"trained_weights/streamspot/streamspot.model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 689309,
     "status": "ok",
     "timestamp": 1673566932746,
     "user": {
      "displayName": "Mati Ur Rehman",
      "userId": "04281203290774044297"
     },
     "user_tz": 300
    },
    "id": "Gclj6HVL17lD",
    "outputId": "f60fadea-a7db-471f-defe-fee744f6ef25",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torch_geometric import utils\n",
    "\n",
    "if Train_Gnn:\n",
    "    for i in range(300):\n",
    "        f = open(f\"streamspot/{i}.txt\")\n",
    "        data = f.read().split('\\n')\n",
    "\n",
    "        data = [line.split('\\t') for line in data]\n",
    "        df = pd.DataFrame (data, columns = ['actorID', 'actor_type','objectID','object','action','timestamp'])\n",
    "        df = df.dropna()\n",
    "        phrases,labels,edges,mapp = prepare_graph(df)\n",
    "\n",
    "        criterion = CrossEntropyLoss()\n",
    "\n",
    "        nodes = [infer(x) for x in phrases]\n",
    "        nodes = np.array(nodes)  \n",
    "\n",
    "        graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))\n",
    "\n",
    "        model.train()\n",
    "        optimizer.zero_grad() \n",
    "        out = model(graph.x, graph.edge_index) \n",
    "        loss = criterion(out, graph.y) \n",
    "        loss.backward() \n",
    "        optimizer.step() \n",
    "\n",
    "        _ , indices = out.sort(dim=1,descending=True)\n",
    "        pred = indices[:,0]\n",
    "        cond = pred == graph.y\n",
    "\n",
    "        print(cond.sum() / len(graph.y))\n",
    "\n",
    "        torch.save(model.state_dict(), f'trained_weights/streamspot/lstreamspot.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.load_state_dict(torch.load(f'trained_weights/streamspot/lstreamspot.pth', map_location=torch.device('cpu')))\n",
    "model.eval()\n",
    "\n",
    "for i in range(400,450):\n",
    "    f = open(f\"streamspot/{i}.txt\")\n",
    "    data = f.read().split('\\n')\n",
    "\n",
    "    data = [line.split('\\t') for line in data]\n",
    "    df = pd.DataFrame (data, columns = ['actorID', 'actor_type','objectID','object','action','timestamp'])\n",
    "    df = df.dropna()\n",
    "    \n",
    "    phrases,labels,edges,mapp = prepare_graph(df)\n",
    "\n",
    "    nodes = [infer(x) for x in phrases]\n",
    "    nodes = np.array(nodes)\n",
    "    \n",
    "    graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))\n",
    "    graph.n_id = torch.arange(graph.num_nodes)\n",
    "    flag = torch.tensor([True]*graph.num_nodes, dtype=torch.bool)\n",
    "\n",
    "    out = model(graph.x, graph.edge_index)\n",
    "\n",
    "    sorted, indices = out.sort(dim=1,descending=True)\n",
    "    conf = (sorted[:,0] - sorted[:,1]) / sorted[:,0]\n",
    "    conf = (conf - conf.min()) / conf.max()\n",
    "\n",
    "    pred = indices[:,0]\n",
    "    cond = ~(pred == graph.y)\n",
    "    \n",
    "    print(cond.sum().item(), (cond.sum().item() / len(cond))*100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "thresh = 200\n",
    "correct_benign = 0\n",
    "correct_attack = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model.load_state_dict(torch.load(f'trained_weights/streamspot/lstreamspot.pth',map_location=torch.device('cpu')))\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for i in range(450,600):\n",
    "    f = open(f\"streamspot/{i}.txt\")\n",
    "    data = f.read().split('\\n')\n",
    "\n",
    "    data = [line.split('\\t') for line in data]\n",
    "    df = pd.DataFrame (data, columns = ['actorID', 'actor_type','objectID','object','action','timestamp'])\n",
    "    df = df.dropna()\n",
    "    \n",
    "    phrases,labels,edges,mapp = prepare_graph(df)\n",
    "\n",
    "    nodes = [infer(x) for x in phrases]\n",
    "    nodes = np.array(nodes)\n",
    "    \n",
    "    graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))\n",
    "    graph.n_id = torch.arange(graph.num_nodes)\n",
    "    flag = torch.tensor([True]*graph.num_nodes, dtype=torch.bool)\n",
    "\n",
    "    out = model(graph.x, graph.edge_index)\n",
    "\n",
    "    sorted, indices = out.sort(dim=1,descending=True)\n",
    "    conf = (sorted[:,0] - sorted[:,1]) / sorted[:,0]\n",
    "    conf = (conf - conf.min()) / conf.max()\n",
    "\n",
    "    pred = indices[:,0]\n",
    "    cond = ~(pred == graph.y)\n",
    "\n",
    "    if cond.sum() <= thresh:\n",
    "         correct_benign = correct_benign + 1\n",
    "    \n",
    "    print(cond.sum().item(), (cond.sum().item() / len(cond))*100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for i in range(300,400):\n",
    "    f = open(f\"streamspot/{i}.txt\")\n",
    "    data = f.read().split('\\n')\n",
    "\n",
    "    data = [line.split('\\t') for line in data]\n",
    "    df = pd.DataFrame (data, columns = ['actorID', 'actor_type','objectID','object','action','timestamp'])\n",
    "    df = df.dropna()\n",
    "    \n",
    "    phrases,labels,edges,mapp = prepare_graph(df)\n",
    "  \n",
    "    nodes = [infer(x) for x in phrases]\n",
    "    nodes = np.array(nodes)\n",
    "    \n",
    "    graph = Data(x=torch.tensor(nodes,dtype=torch.float).to(device),y=torch.tensor(labels,dtype=torch.long).to(device), edge_index=torch.tensor(edges,dtype=torch.long).to(device))\n",
    "    graph.n_id = torch.arange(graph.num_nodes)\n",
    "    flag = torch.tensor([True]*graph.num_nodes, dtype=torch.bool)\n",
    "\n",
    "    out = model(graph.x, graph.edge_index)\n",
    "\n",
    "    sorted, indices = out.sort(dim=1,descending=True)\n",
    "    conf = (sorted[:,0] - sorted[:,1]) / sorted[:,0]\n",
    "    conf = (conf - conf.min()) / conf.max()\n",
    "\n",
    "    pred = indices[:,0]\n",
    "    cond = ~(pred == graph.y)\n",
    "\n",
    "    if cond.sum() > thresh:\n",
    "         correct_attack = correct_attack + 1\n",
    "            \n",
    "    print(cond.sum().item(), (cond.sum().item() / len(cond))*100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "TOTAL_ATTACKS = 100\n",
    "TOTAL_BENIGN = 150\n",
    "\n",
    "def calculate_metrics(correct_attack, correct_benign):\n",
    "    TP = correct_attack\n",
    "    FP = TOTAL_BENIGN - correct_benign\n",
    "    TN = correct_benign\n",
    "    FN = TOTAL_ATTACKS - correct_attack\n",
    "\n",
    "    FPR = FP / (FP + TN) if (FP + TN) > 0 else 0\n",
    "    TPR = TP / (TP + FN) if (TP + FN) > 0 else 0\n",
    "\n",
    "    print(f\"Number of True Positives: {TP}\")\n",
    "    print(f\"Number of False Positives: {FP}\")\n",
    "    print(f\"Number of False Negatives: {FN}\")\n",
    "    print(f\"Number of True Negatives: {TN}\\n\")\n",
    "\n",
    "    precision = TP / (TP + FP) if (TP + FP) > 0 else 0\n",
    "    recall = TP / (TP + FN) if (TP + FN) > 0 else 0\n",
    "    print(f\"Precision: {precision}\")\n",
    "    print(f\"Recall: {recall}\")\n",
    "\n",
    "    fscore = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0\n",
    "    print(f\"Fscore: {fscore}\\n\")\n",
    "\n",
    "calculate_metrics(correct_attack, correct_benign)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "PyTorch 1.12.0",
   "language": "python",
   "name": "pytorch-1.12.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
